# A Pure Exploration problem (pep) is described by the following:
# - a query, as embodied by a correct-answer function istar
# - nanswers: number of possible answers
# - istar: correct answer for feasible μ
# - glrt: value and best response (λ and ξ) to (N, μ) or (w, μ)
# - oracle: characteristic time and oracle weights at μ

```
Best Arm
```

"""
    BestArm(dists, B)

Best-arm identification problem.

# Fields
- `dists`: array of bounded distributions, one entry per arm.
- `B`: upper bound.
"""
struct BestArm
    dists  # array of bounded distributions
    B      # upper bound
end

# Number of possible answers: one per entry of `pep.dists`.
nanswers(pep::BestArm) = length(pep.dists)

# Correct answer for the mean vector μ: index of its maximal entry.
istar(pep::BestArm, μ) = argmax(μ)

# Distribution stored at index k.
getdist(pep::BestArm, k) = pep.dists[k]

# Upper bound B stored in the problem.
getB(pep::BestArm) = pep.B

# True when the type of every stored distribution is exactly BernoulliGen.
isBer(pep::BestArm) = all(dist -> typeof(dist) == BernoulliGen, pep.dists)

# Human-readable description of the problem.
function long(pep::BestArm)
    family = isBer(pep) ? "Bernoulli" : "Bounded"
    return string("BAI for ", family, " in [0,", string(pep.B), "]")
end

# Alternative parameter
"""
    alt_λ(pep, samples1, μ1, w1, samplesa, μa, wa)

Return the alternative parameter (a common value `u`) for the pair formed by
arm "1" (mean `μ1`, weight `w1`, samples `samples1`) and arm "a" (mean `μa`,
weight `wa`, samples `samplesa`).

Both means must lie in `[0, B]` with `B = getB(pep)`, and `μa < μ1` must hold;
these invariants are checked with `@assert`.

- When `isBer(pep)` is true, the result is the weighted average
  `(w1 * μ1 + wa * μa) / (w1 + wa)`.
- Otherwise the result is the minimizer over `u` in `[μa, μ1]` of
  `w1 * Kinf_emp(..., μ1, u, B, false) + wa * Kinf_emp(..., μa, u, B, true)`,
  computed with `Optim.optimize` on that interval.

If the optimizer does not report convergence, or throws, a warning is logged
and `Inf` is returned (best-effort behaviour). An `InterruptException` is
rethrown so that the user can still abort a long computation.
"""
function alt_λ(pep, samples1, μ1, w1, samplesa, μa, wa)
    B = getB(pep);
    _isBer = isBer(pep);
    @assert μa <= B && μa >= 0 "Domain violation for μa: $(μa) ∈ [0, $(B)]"
    @assert μ1 <= B && μ1 >= 0 "Domain violation for μ1: $(μ1) ∈ [0, $(B)]"
    @assert μa < μ1 "1 should be the best arm: $(μa) < $(μ1)"

    if _isBer
        # Closed form: weighted average of the two means.
        return (w1 * μ1 + wa * μa) / (w1 + wa)
    end

    # Weighted sum of the two Kinf_emp terms evaluated at a common value u.
    objective(u) = w1 * Kinf_emp(_isBer, samples1, μ1, u, B, false) +
                   wa * Kinf_emp(_isBer, samplesa, μa, u, B, true)

    try
        res = Optim.optimize(objective, μa, μ1)
        if Optim.converged(res)
            return Optim.minimizer(res)
        end
        @warn "Optimization failed" μ1 μa w1 wa
        return Inf
    catch e
        # Never swallow a user interrupt; every other failure stays best-effort.
        e isa InterruptException && rethrow()
        @warn "Optimization failed" exception = (e, catch_backtrace())
        return Inf
    end
end

"""
    glrt(pep, w, μ, Xs, bonus="")

Evaluate the per-arm statistic against the best arm of `μ` and build the
closest alternative.

For every arm `a` whose mean is strictly below the best mean, the value is
`w[astar] * Kinf_emp(..., θ, B, false) + w[a] * Kinf_emp(..., θ, B, true)`
with `θ = alt_λ(...)`; when `bonus` is `"log"` or `"sqrt"`, `log(w[a])` or
`sqrt(w[a])` is added. Arms tied with the best arm get value `0`, and the best
arm itself keeps value `Inf`.

Returns `vals, (k, λ), (astar, μ)` where `k = argmin(vals)` and `λ` is a copy
of `μ` in which entries `astar` and `k` are both replaced by the alternative
parameter of arm `k`.
"""
function glrt(pep, w, μ, Xs, bonus="")
    B = getB(pep)
    bernoulli = isBer(pep)
    @assert length(size(μ)) == 1
    narms = length(μ)

    best = argmax(μ)  # index of best arm among μ

    vals = fill(Inf, narms)
    θs = zeros(narms)
    for a in 1:narms
        # The best arm keeps its Inf value and is never a challenger.
        a == best && continue

        if μ[a] < μ[best]
            θs[a] = alt_λ(pep, Xs[best], μ[best], w[best], Xs[a], μ[a], w[a])
            θ = θs[a]
            vals[a] = w[best] * Kinf_emp(bernoulli, Xs[best], μ[best], θ, B, false) +
                      w[a] * Kinf_emp(bernoulli, Xs[a], μ[a], θ, B, true)
            if bonus == "log"
                vals[a] += log(w[a])
            elseif bonus == "sqrt"
                vals[a] += sqrt(w[a])
            end
        else
            # Arm tied with the best one: its value is zero.
            θs[a] = μ[a]
            vals[a] = 0
        end
    end

    k = argmin(vals)

    # Closest alternative: best arm and arm k are moved to the same value.
    λ = copy(μ)
    λ[best] = θs[k]
    λ[k] = θs[k]

    return vals, (k, λ), (best, μ)
end


# Solve for x such that d1(μx) + x*da(μx) == v
"""
    X(pep, samples1, μ1, samplesa, μa, v)

Solve for `x` such that `d1(μx) + x*da(μx) == v`.

The problem is reparametrised with `α` in `[0, 1]` and `x = α / (1 - α)`:
`binary_search` is run on
`(1 - α) * d1(u) + α * da(u) - (1 - α) * v`, where `u` is
`alt_λ(pep, samples1, μ1, 1 - α, samplesa, μa, α)` and `d1`, `da` are the
`Kinf_emp` terms of arm 1 and arm a (presumably `binary_search` returns a
root of that function; it is defined elsewhere).

`v` must satisfy `0 ≤ v ≤ Kinf_emp(..., μ1, μa, B, false)`.

Returns the pair `(x, u)` evaluated at the `α` that was found.
"""
function X(pep, samples1, μ1, samplesa, μa, v)
    B = getB(pep)
    bernoulli = isBer(pep)
    upd_a = Kinf_emp(bernoulli, samples1, μ1, μa, B, false)  # range of V(x) is [0, upd_a]
    @assert 0 ≤ v ≤ upd_a "0 ≤ $v ≤ $upd_a"

    # Function of the mixing weight z handed to binary_search.
    function gap(z)
        uz = alt_λ(pep, samples1, μ1, 1 - z, samplesa, μa, z)
        mixed = (1 - z) * Kinf_emp(bernoulli, samples1, μ1, uz, B, false) +
                z * Kinf_emp(bernoulli, samplesa, μa, uz, B, true)
        return mixed - (1 - z) * v
    end

    α = binary_search(gap, 0, 1, ϵ = upd_a * 1e-10)
    return α / (1 - α), alt_λ(pep, samples1, μ1, 1 - α, samplesa, μa, α)
end

# Oracle solution
"""
    oracle(pep, μs, Xs)

Characteristic time and oracle weights at `μs`.

When all entries of `μs` are equal, returns `Inf` together with uniform
weights. Otherwise a common level `val` is searched with `binary_search` on
`[0, hi]`, where `hi` is the smallest `Kinf_emp` value between the best arm
and any other arm's mean. The function handed to `binary_search` is the sum,
over the non-best arms, of the ratio of the two `Kinf_emp` terms at the
alternative parameter given by `X`, minus `1`.

The unnormalised weights are `1` for the best arm and the first output of
`X(..., val)` for every other arm. With `Σ` their sum, the function returns
`(Σ / val, ws ./ Σ)`.
"""
function oracle(pep, μs, Xs)
    B = getB(pep)
    bernoulli = isBer(pep)
    μstar = maximum(μs)

    if all(==(μstar), μs)  # yes, this happens
        return Inf, ones(length(μs)) / length(μs)
    end

    astar = argmax(μs)

    # Kinf_emp term of the best arm evaluated at u.
    kinf_best(u) = Kinf_emp(bernoulli, Xs[astar], μs[astar], u, B, false)
    # Output of X for the pair (best arm, arm k) at level v.
    solve(k, v) = X(pep, Xs[astar], μs[astar], Xs[k], μs[k], v)

    # determine upper range for subsequent binary search
    hi = minimum(kinf_best(μs[k]) for k in eachindex(μs) if k != astar)

    # Contribution of arm k at level z: ratio of the two Kinf_emp terms.
    function ratio(k, z)
        ux = solve(k, z)[2]
        return kinf_best(ux) / Kinf_emp(bernoulli, Xs[k], μs[k], ux, B, true)
    end
    balance(z) = sum(ratio(k, z) for k in eachindex(μs) if k != astar) - 1.0

    val = binary_search(balance, 0, hi)

    ws = [k == astar ? 1.0 : solve(k, val)[1] for k in eachindex(μs)]
    total = sum(ws)
    return total / val, ws ./ total
end
